package self_test

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	mrand "math/rand/v2"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/synctest"
	"github.com/quic-go/quic-go/qlog"
	"github.com/quic-go/quic-go/qlogwriter"
	"github.com/quic-go/quic-go/testutils/simnet"

	"github.com/stretchr/testify/require"
)

// requireIdleTimeoutError asserts that err reports an idle timeout.
//
// The error must be non-nil, must match *quic.IdleTimeoutError with Timeout
// returning true, and must also match net.Error with Timeout returning true.
// The test is stopped at the first assertion that does not hold.
func requireIdleTimeoutError(t *testing.T, err error) {
	t.Helper()

	require.Error(t, err)
	var idleTimeoutErr *quic.IdleTimeoutError
	require.ErrorAs(t, err, &idleTimeoutErr)
	require.True(t, idleTimeoutErr.Timeout())
	// Use require.ErrorAs here as well (instead of require.True(errors.As(...)))
	// so that a mismatch reports the error rather than just "Should be true".
	var nerr net.Error
	require.ErrorAs(t, err, &nerr)
	require.True(t, nerr.Timeout())
}

func TestHandshakeIdleTimeout(t *testing.T) {
	t.Run("Dial", func(t *testing.T) {
		testHandshakeIdleTimeout(t, quic.Dial)
	})

	t.Run("DialEarly", func(t *testing.T) {
		testHandshakeIdleTimeout(t, quic.DialEarly)
	})
}

// testHandshakeIdleTimeout dials the server end of a simnet link on which this
// test starts no QUIC listener, and requires dialFn to fail with an idle
// timeout error. The time elapsed between starting the dial and observing the
// error must equal the configured HandshakeIdleTimeout.
//
// dialFn is the dial entry point under test (e.g. quic.Dial or quic.DialEarly).
func testHandshakeIdleTimeout(t *testing.T, dialFn func(context.Context, net.PacketConn, net.Addr, *tls.Config, *quic.Config) (*quic.Conn, error)) {
	synctest.Test(t, func(t *testing.T) {
		const handshakeIdleTimeout = 3 * time.Second

		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond)
		defer closeFn(t)

		// Buffered so that the dialing goroutine can always deliver its result and exit.
		dialErrs := make(chan error, 1)
		dialedAt := time.Now()
		go func() {
			remote := serverPacketConn.LocalAddr()
			tlsConf := getTLSClientConfig()
			quicConf := getQuicConfig(&quic.Config{HandshakeIdleTimeout: handshakeIdleTimeout})
			_, err := dialFn(context.Background(), clientPacketConn, remote, tlsConf, quicConf)
			dialErrs <- err
		}()

		var dialErr error
		select {
		case dialErr = <-dialErrs:
		case <-time.After(5 * time.Second):
			t.Fatal("timeout waiting for dial error")
		}
		requireIdleTimeoutError(t, dialErr)
		require.Equal(t, handshakeIdleTimeout, time.Since(dialedAt))
	})
}

// TestIdleTimeout establishes a connection, transfers a few bytes on a
// server-initiated stream, and then makes the router's Drop callback return
// true for every packet. Both the server and the client connection must close
// no earlier than idleTimeout after their respective start marks. Afterwards,
// every stream and connection operation on the client must fail with an idle
// timeout error.
func TestIdleTimeout(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const idleTimeout = 20 * time.Second

		// While blackhole is set, the Drop callback reports true for each packet.
		var blackhole atomic.Bool
		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
			time.Millisecond,
			&droppingRouter{Drop: func(simnet.Packet) bool { return blackhole.Load() }},
		)
		defer closeFn(t)

		ln, err := quic.Listen(
			serverPacketConn,
			getTLSConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)
		defer ln.Close()

		clientConn, err := quic.Dial(
			context.Background(),
			clientPacketConn,
			serverPacketConn.LocalAddr(),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true, MaxIdleTimeout: idleTimeout}),
		)
		require.NoError(t, err)

		serverConn, err := ln.Accept(context.Background())
		require.NoError(t, err)
		serverStr, err := serverConn.OpenStream()
		require.NoError(t, err)
		_, err = serverStr.Write([]byte("foobar"))
		require.NoError(t, err)

		serverStart := time.Now()

		strIn, err := clientConn.AcceptStream(context.Background())
		require.NoError(t, err)
		strOut, err := clientConn.OpenStream()
		require.NoError(t, err)
		_, err = strIn.Read(make([]byte, 6))
		require.NoError(t, err)

		clientStart := time.Now()

		blackhole.Store(true)

		// waitClosed blocks until ctx is done and returns the time elapsed since
		// start. It fails the test if ctx is not done within twice the idle timeout.
		waitClosed := func(ctx context.Context, start time.Time) time.Duration {
			t.Helper()
			select {
			case <-ctx.Done():
			case <-time.After(2 * idleTimeout):
				t.Fatal("timeout waiting for idle timeout")
			}
			return time.Since(start)
		}

		serverTook := waitClosed(serverConn.Context(), serverStart)
		require.GreaterOrEqual(t, serverTook, idleTimeout)
		t.Logf("server connection timed out after %s (idle timeout: %s)", serverTook, idleTimeout)

		clientTook := waitClosed(clientConn.Context(), clientStart)
		require.GreaterOrEqual(t, clientTook, idleTimeout)
		t.Logf("client connection timed out after %s (idle timeout: %s)", clientTook, idleTimeout)

		// Existing streams: both directions must report the idle timeout.
		_, err = strIn.Write([]byte("test"))
		requireIdleTimeoutError(t, err)
		_, err = strIn.Read([]byte{0})
		requireIdleTimeoutError(t, err)
		_, err = strOut.Write([]byte("test"))
		requireIdleTimeoutError(t, err)
		_, err = strOut.Read([]byte{0})
		requireIdleTimeoutError(t, err)

		// Opening or accepting new streams must report it as well.
		_, err = clientConn.OpenStream()
		requireIdleTimeoutError(t, err)
		_, err = clientConn.OpenUniStream()
		requireIdleTimeoutError(t, err)
		_, err = clientConn.AcceptStream(context.Background())
		requireIdleTimeoutError(t, err)
		_, err = clientConn.AcceptUniStream(context.Background())
		requireIdleTimeoutError(t, err)
	})
}

// TestKeepAlive checks that a client configured with a KeepAlivePeriod of half
// the idle timeout still has a usable connection after sleeping for three
// times the idle timeout, and that a write fails with an idle timeout error
// once the router's Drop callback has returned true for every packet for
// twice the idle timeout.
func TestKeepAlive(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const (
			idleTimeout     = 4 * time.Second
			keepAlivePeriod = idleTimeout / 2
		)

		var dropPackets atomic.Bool
		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
			time.Millisecond,
			&droppingRouter{Drop: func(simnet.Packet) bool { return dropPackets.Load() }},
		)
		defer closeFn(t)

		ln, err := quic.Listen(
			serverPacketConn,
			getTLSConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)
		defer ln.Close()

		// The same one-second context bounds both the dial and the accept.
		setupCtx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		clientConn, err := quic.Dial(
			setupCtx,
			clientPacketConn,
			serverPacketConn.LocalAddr(),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{
				MaxIdleTimeout:          idleTimeout,
				KeepAlivePeriod:         keepAlivePeriod,
				DisablePathMTUDiscovery: true,
			}),
		)
		require.NoError(t, err)

		serverConn, err := ln.Accept(setupCtx)
		require.NoError(t, err)

		// Stay idle for longer than the idle timeout.
		time.Sleep(3 * idleTimeout)
		uniStr, err := clientConn.OpenUniStream()
		require.NoError(t, err)
		_, err = uniStr.Write([]byte("foobar"))
		require.NoError(t, err)

		// The server side must not have closed in the meantime.
		select {
		case <-serverConn.Context().Done():
			t.Fatal("server connection closed unexpectedly")
		default:
		}

		// With the PINGs being dropped, the idle timeout still fires.
		dropPackets.Store(true)
		time.Sleep(2 * idleTimeout)
		_, err = uniStr.Write([]byte("foobar"))
		requireIdleTimeoutError(t, err)

		// can't rely on the server connection closing, since we impose a minimum idle timeout of 5s,
		// see https://github.com/quic-go/quic-go/issues/4751
		serverConn.CloseWithError(0, "")
	})
}

// TestTimeoutAfterInactivity checks that a client connection on which the
// application does nothing is closed with an idle timeout error, and that the
// time between the last activity recorded by the packet tracer and the moment
// the error is observed equals the idle timeout.
//
// "Last activity" is the later of: the last short header packet sent that
// contains a frame other than an ACK frame, and the last short header packet
// received. The server connection is required to still be open at that point.
func TestTimeoutAfterInactivity(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const idleTimeout = 15 * time.Second

		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, time.Millisecond)
		defer closeFn(t)

		server, err := quic.Listen(
			serverPacketConn,
			getTLSConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)
		defer server.Close()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		counter, tr := newPacketTracer()
		conn, err := quic.Dial(
			ctx,
			clientPacketConn,
			server.Addr(),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{
				MaxIdleTimeout:          idleTimeout,
				Tracer:                  func(context.Context, bool, quic.ConnectionID) qlogwriter.Trace { return tr },
				DisablePathMTUDiscovery: true,
			}),
		)
		require.NoError(t, err)

		serverConn, err := server.Accept(ctx)
		require.NoError(t, err)
		defer serverConn.CloseWithError(0, "")

		// No stream is ever opened by the server, so AcceptStream can only
		// return once the connection is closed (or the context expires).
		ctx, cancel = context.WithTimeout(context.Background(), 2*idleTimeout)
		defer cancel()
		_, err = conn.AcceptStream(ctx)
		requireIdleTimeoutError(t, err)

		// Find the send time of the last packet carrying anything but ACK frames.
		var lastAckElicitingPacketSentAt time.Time
		for _, p := range counter.getSentShortHeaderPackets() {
			var hasAckElicitingFrame bool
			for _, f := range p.frames {
				if _, ok := f.Frame.(qlog.AckFrame); ok {
					continue
				}
				hasAckElicitingFrame = true
				break
			}
			if hasAckElicitingFrame {
				lastAckElicitingPacketSentAt = p.time
			}
		}
		rcvdPackets := counter.getRcvdShortHeaderPackets()
		// Fail with a readable message instead of an index-out-of-range panic
		// if the tracer recorded no received short header packet.
		require.NotEmpty(t, rcvdPackets)
		lastPacketRcvdAt := rcvdPackets[len(rcvdPackets)-1].time
		// We're ignoring here that only the first ack-eliciting packet sent resets the idle timeout.
		// This is ok since we're dealing with a lossless connection here,
		// and we'd expect to receive an ACK for every additional ack-eliciting packet sent.
		timeSinceLastAckEliciting := time.Since(lastAckElicitingPacketSentAt)
		timeSinceLastRcvd := time.Since(lastPacketRcvdAt)
		require.Equal(t, idleTimeout, max(timeSinceLastAckEliciting, timeSinceLastRcvd))

		select {
		case <-serverConn.Context().Done():
			t.Fatal("server connection closed unexpectedly")
		default:
		}
	})
}

// TestTimeoutAfterSendingPacket checks when each side times out if the link
// goes dark halfway through the idle period and the client then sends data.
//
// After the handshake the test sleeps for half the idle timeout, makes the
// router's Drop callback return true for every packet, and writes stream data
// on the client. The server connection must close at least idleTimeout, and
// less than idleTimeout plus one second, after serverStart (taken right after
// Accept). The client connection must close exactly idleTimeout after
// clientStart, which is taken just before the stream is opened and written.
func TestTimeoutAfterSendingPacket(t *testing.T) {
	synctest.Test(t, func(t *testing.T) {
		const idleTimeout = 15 * time.Second

		var drop atomic.Bool
		clientPacketConn, serverPacketConn, closeFn := newSimnetLinkWithRouter(t,
			time.Millisecond,
			&droppingRouter{Drop: func(p simnet.Packet) bool { return drop.Load() }},
		)
		defer closeFn(t)

		server, err := quic.Listen(
			serverPacketConn,
			getTLSConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)
		defer server.Close()

		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		conn, err := quic.Dial(
			ctx,
			clientPacketConn,
			serverPacketConn.LocalAddr(),
			getTLSClientConfig(),
			getQuicConfig(&quic.Config{MaxIdleTimeout: idleTimeout, DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)

		serverConn, err := server.Accept(ctx)
		require.NoError(t, err)

		serverStart := time.Now()

		// wait half the idle timeout, then send a packet
		time.Sleep(idleTimeout / 2)
		drop.Store(true)

		clientStart := time.Now()
		str, err := conn.OpenUniStream()
		require.NoError(t, err)
		_, err = str.Write([]byte("foobar"))
		require.NoError(t, err)

		select {
		case <-serverConn.Context().Done():
			took := time.Since(serverStart)
			require.GreaterOrEqual(t, took, idleTimeout)
			require.Less(t, took, idleTimeout+time.Second)
		case <-time.After(2 * idleTimeout):
			t.Fatal("timeout waiting for idle timeout")
		}

		select {
		case <-conn.Context().Done():
			took := time.Since(clientStart)
			// Expected value first, so that a failure labels the two durations correctly.
			require.Equal(t, idleTimeout, took)
		case <-time.After(2 * idleTimeout):
			t.Fatal("timeout waiting for idle timeout")
		}
	})
}

// faultyConn wraps a net.PacketConn and starts failing with io.ErrClosedPipe
// once more than MaxPackets ReadFrom and WriteTo calls (counted together) have
// been made on it.
type faultyConn struct {
	net.PacketConn

	// MaxPackets is the number of read and write calls that still succeed.
	MaxPackets int
	// counter is the number of ReadFrom and WriteTo calls made so far.
	counter atomic.Int32
}

// ReadFrom always reads from the wrapped connection first, and only then
// counts the call. If the call exceeds the budget, the result of the read is
// discarded and io.ErrClosedPipe is returned instead.
func (c *faultyConn) ReadFrom(p []byte) (int, net.Addr, error) {
	n, from, readErr := c.PacketConn.ReadFrom(p)
	if c.counter.Add(1) > int32(c.MaxPackets) {
		return 0, nil, io.ErrClosedPipe
	}
	return n, from, readErr
}

// WriteTo counts the call, and forwards it to the wrapped connection only if
// it is still within the budget. Otherwise nothing is written and
// io.ErrClosedPipe is returned.
func (c *faultyConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	if c.counter.Add(1) > int32(c.MaxPackets) {
		return 0, io.ErrClosedPipe
	}
	return c.PacketConn.WriteTo(p, addr)
}

func TestFaultyPacketConn(t *testing.T) {
	t.Run("client", func(t *testing.T) {
		testFaultyPacketConn(t, protocol.PerspectiveClient)
	})

	t.Run("server", func(t *testing.T) {
		testFaultyPacketConn(t, protocol.PerspectiveServer)
	})
}

// testFaultyPacketConn transfers PRData from the server to the client while
// the packet conn of the side selected by pers starts failing with
// io.ErrClosedPipe after a random number (fewer than 25) of read and write
// calls.
//
// The side with the faulty conn is required to fail with an error containing
// io.ErrClosedPipe's message; the other side is required to fail with a
// net.Error that reports a timeout. If the server never got as far as
// returning from its run function, the listener is closed and the test only
// waits for the server goroutine to finish.
func testFaultyPacketConn(t *testing.T, pers protocol.Perspective) {
	t.Setenv("QUIC_GO_DISABLE_RECEIVE_BUFFER_WARNING", "true")

	synctest.Test(t, func(t *testing.T) {
		// serve accepts a single connection and sends PRData on a unidirectional stream.
		serve := func(ln *quic.Listener) error {
			conn, err := ln.Accept(context.Background())
			if err != nil {
				return err
			}
			str, err := conn.OpenUniStream()
			if err != nil {
				return err
			}
			defer str.Close()
			_, err = str.Write(PRData)
			return err
		}

		// receive reads a unidirectional stream to the end, compares the data
		// with PRData, and closes the connection.
		receive := func(conn *quic.Conn) error {
			str, err := conn.AcceptUniStream(context.Background())
			if err != nil {
				return err
			}
			got, err := io.ReadAll(str)
			switch {
			case err != nil:
				return err
			case !bytes.Equal(got, PRData):
				return fmt.Errorf("wrong data: %q vs %q", got, PRData)
			}
			return conn.CloseWithError(0, "done")
		}

		// requireClosedPipe asserts that err mentions io.ErrClosedPipe.
		requireClosedPipe := func(err error) {
			t.Helper()
			require.Contains(t, err.Error(), io.ErrClosedPipe.Error())
		}
		// requireNetTimeout asserts that err is a net.Error reporting a timeout.
		requireNetTimeout := func(err error) {
			t.Helper()
			var nerr net.Error
			require.True(t, errors.As(err, &nerr))
			require.True(t, nerr.Timeout())
		}

		clientPacketConn, serverPacketConn, closeFn := newSimnetLink(t, 100*time.Millisecond)
		defer closeFn(t)

		// sanity check: sending PRData should generate at least 25 packets
		require.Greater(t, len(PRData)/1500, 25)
		budget := mrand.IntN(25)

		t.Logf("blocking %s's connection after %d packets", pers, budget)
		var clientSide net.PacketConn = clientPacketConn
		var serverSide net.PacketConn = serverPacketConn
		if pers == protocol.PerspectiveClient {
			clientSide = &faultyConn{PacketConn: clientSide, MaxPackets: budget}
		} else if pers == protocol.PerspectiveServer {
			serverSide = &faultyConn{PacketConn: serverSide, MaxPackets: budget}
		}

		ln, err := quic.Listen(
			serverSide,
			getTLSConfig(),
			getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
		)
		require.NoError(t, err)
		defer ln.Close()

		serverErrs := make(chan error, 1)
		go func() { serverErrs <- serve(ln) }()

		dialAndReceive := func() error {
			conn, err := quic.Dial(
				context.Background(),
				clientSide,
				ln.Addr(),
				getTLSClientConfig(),
				getQuicConfig(&quic.Config{DisablePathMTUDiscovery: true}),
			)
			if err != nil {
				return err
			}
			return receive(conn)
		}
		clientErrs := make(chan error, 1)
		go func() { clientErrs <- dialAndReceive() }()

		var clientErr error
		select {
		case clientErr = <-clientErrs:
		case <-time.After(time.Hour):
			t.Fatal("timeout waiting for client error")
		}
		require.Error(t, clientErr)
		if pers == protocol.PerspectiveClient {
			requireClosedPipe(clientErr)
		} else {
			requireNetTimeout(clientErr)
		}

		select {
		case serverErr := <-serverErrs: // The handshake completed on the server side.
			require.Error(t, serverErr)
			if pers == protocol.PerspectiveServer {
				requireClosedPipe(serverErr)
			} else {
				requireNetTimeout(serverErr)
			}
		default: // The handshake didn't complete
			require.NoError(t, ln.Close())
			select {
			case <-serverErrs:
			case <-time.After(time.Hour):
				t.Fatal("timeout waiting for server to close")
			}
		}
	})
}
